{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "\n",
    "import IPython\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import soundfile as sf\n",
    "import time\n",
    "from tqdm import tqdm\n",
    "import tensorflow as tf\n",
    "\n",
    "from nara_wpe.tf_wpe import wpe\n",
    "from nara_wpe.tf_wpe import online_wpe_step, get_power_online\n",
    "from nara_wpe.utils import stft, istft, get_stft_center_frequencies\n",
    "from nara_wpe import project_root"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stft_options = dict(\n",
    "    size=512,\n",
    "    shift=128,\n",
    "    window_length=None,\n",
    "    fading=True,\n",
    "    pad=True,\n",
    "    symmetric_window=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example with real audio recordings\n",
    "The iterations are dropped in contrast to the offline version. To use past observations the correlation matrix and the correlation vector are calculated recursively with a decaying window. $\\alpha$ is the decay factor."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "channels = 8\n",
    "sampling_rate = 16000\n",
    "delay = 3\n",
    "alpha=0.99\n",
    "taps = 10\n",
    "frequency_bins = stft_options['size'] // 2 + 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Audio data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "file_template = 'AMI_WSJ20-Array1-{}_T10c0201.wav'\n",
    "signal_list = [\n",
    "    sf.read(str(project_root / 'data' / file_template.format(d + 1)))[0]\n",
    "    for d in range(channels)\n",
    "]\n",
    "y = np.stack(signal_list, axis=0)\n",
    "IPython.display.Audio(y[0], rate=sampling_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Online buffer\n",
    "For simplicity the STFT is performed before providing the frames.\n",
    "\n",
    "Shape: (frames, frequency bins, channels)\n",
    "\n",
    "frames: K+delay+1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Y = stft(y, **stft_options).transpose(1, 2, 0)\n",
    "T, _, _ = Y.shape\n",
    "\n",
    "def aquire_framebuffer():\n",
    "    buffer = list(Y[:taps+delay+1, :, :])\n",
    "    for t in range(taps+delay+1, T):\n",
    "        yield np.array(buffer)\n",
    "        buffer.append(Y[t, :, :])\n",
    "        buffer.pop(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Non-iterative frame online approach\n",
    "A frame online example requires, that certain state variables are kept from frame to frame. That is the inverse correlation matrix $\\text{R}_{t, f}^{-1}$ which is stored in Q and initialized with an identity matrix, as well as filter coefficient matrix that is stored in G and initialized with zeros. \n",
    "\n",
    "Again for simplicity the ISTFT is applied in Numpy afterwards."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Z_list = []\n",
    "\n",
    "Q = np.stack([np.identity(channels * taps) for a in range(frequency_bins)])\n",
    "G = np.zeros((frequency_bins, channels * taps, channels))\n",
    "\n",
    "with tf.Session() as session:\n",
    "    Y_tf = tf.placeholder(tf.complex128, shape=(taps + delay + 1, frequency_bins, channels))\n",
    "    Q_tf = tf.placeholder(tf.complex128, shape=(frequency_bins, channels * taps, channels * taps))\n",
    "    G_tf = tf.placeholder(tf.complex128, shape=(frequency_bins, channels * taps, channels))\n",
    "    \n",
    "    results = online_wpe_step(Y_tf, get_power_online(tf.transpose(Y_tf, (1, 0, 2))), Q_tf, G_tf, alpha=alpha, taps=taps, delay=delay)\n",
    "    for Y_step in tqdm(aquire_framebuffer()):\n",
    "        feed_dict = {Y_tf: Y_step, Q_tf: Q, G_tf: G}\n",
    "        Z, Q, G = session.run(results, feed_dict)\n",
    "        Z_list.append(Z)\n",
    "\n",
    "Z_stacked = np.stack(Z_list)\n",
    "z = istft(np.asarray(Z_stacked).transpose(2, 0, 1), size=stft_options['size'], shift=stft_options['shift'])\n",
    "\n",
    "IPython.display.Audio(z[0], rate=sampling_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Power spectrum\n",
    "Before and after applying WPE."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 8))\n",
    "im1 = ax1.imshow(20 * np.log10(np.abs(Y[200:400, :, 0])).T, origin='lower')\n",
    "ax1.set_xlabel('')\n",
    "_ = ax1.set_title('reverberated')\n",
    "im2 = ax2.imshow(20 * np.log10(np.abs(Z_stacked[200:400, :, 0])).T, origin='lower')\n",
    "_ = ax2.set_title('dereverberated')\n",
    "cb = fig.colorbar(im1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
